-
Notifications
You must be signed in to change notification settings - Fork 412
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add RMSE option to MSE code #249
Conversation
Hi @johannespitz, return sum_squared_error / n_obs if squared else torch.sqrt(sum_squared_error / n_obs) instead of return sum_squared_error / n_obs |
Should this be just a CompositeMetric(MSE, torch.sqrt)? |
Hi, unfortunately RMSE is not simply the square root of the MSE. Therefore, one can not compose the metric with some arithmetic. import torch
targets = torch.tensor([1.0, 2, 3, 4, -1, -2])
preditions = torch.tensor([0, 1, 5, 2, 0, -2])
lenght = targets.shape[0]
mse = ((targets - preditions) ** 2).sum()
rmse = (((targets - preditions) ** 2) ** 0.5).sum()
print(f"mse {mse / lenght}")
print(f"rmse {rmse / lenght}")
print(f"mse**0.5 {mse ** 0.5 / lenght}") The square root needs to be taken before summing up the individual errors. |
Is that actually true? Wikipedia gives RMSE = sqrt(MSE): https://en.wikipedia.org/wiki/Root-mean-square_deviation |
Oh, my bad. scikit-learn agrees with that definition. In case you want to pull the addition I'd suggest we call it the Average Euclidean Distance. |
I am actually in favour of adding this argument, as RMSE also is considered a standard machine learning algorithm. |
Sorry, about all those commits. But I have tested it, and now I'm getting the expected result. |
Codecov Report
@@ Coverage Diff @@
## master #249 +/- ##
==========================================
+ Coverage 96.81% 96.83% +0.01%
==========================================
Files 92 184 +92
Lines 3012 6026 +3014
==========================================
+ Hits 2916 5835 +2919
- Misses 96 191 +95
Flags with carried forward coverage won't be shown. Click here to find out more.
Continue to review full report at Codecov.
|
for more information, see https://pre-commit.ci
lgtm to me as well, but agree that the extra value might be slim compared to just I see your point about root of each observation. We are getting a few metrics like that recently and it might be worth trying to combine/generalize them. The patter is that we want to have
I do wonder if we will exhaust all of those or we should make a generic metric implementation that will compute average "distance" between pred and target for a pluggable distance. |
@maximsch2 I like the idea of a generic class which can handle all metrics that fit into this
This would also allow me to implement what I wanted (but didn't even implement as you pointed out above) |
What does this PR do?
Fixes #250 .
PR review
I'm open for any kind of feedback. This is a very simple fix and I think it might be helpful for others to have this functionality.